package com.bigdata.spark.projectone

import org.apache.spark.Partitioner

import scala.collection.mutable

/**
 * Custom Spark partitioner that gives each subject (project) its own partition.
 *
 * Partition numbers are assigned by position: `projects(i)` maps to partition `i`.
 *
 * @param projects the subject names; one dedicated partition per entry
 * @author Gerry chan
 * @version 1.0
 */
class ProjectPartitioner(projects: Array[String]) extends Partitioner {

  // Subject -> partition number, built once and immutably from the array
  // indices (replaces the mutable HashMap filled with a var counter).
  private val projectToPartition: Map[String, Int] = projects.zipWithIndex.toMap

  /** One partition per subject. */
  override def numPartitions: Int = projects.length

  /**
   * Returns the partition number for `key`.
   *
   * Unknown keys fall back to partition 0.
   * NOTE(review): unknown keys silently share partition 0 with the first
   * subject — confirm this fallback is intended rather than an error case.
   */
  override def getPartition(key: Any): Int =
    projectToPartition.getOrElse(key.toString, 0)
}
